{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple CNN\n",
    "\n",
    "In this assignment, you need to build a simple CNN to classify images in CIFAR-10 dataset.\n",
    "\n",
    "![cifar10](https://pytorch.org/tutorials/_images/cifar10.png)\n",
    "\n",
    "The framework is given below. What you need to do is to build your net. Then tune the hyperparameters, train the NN, and achieve **60%+** accuracy. (about 1h training on CPU is enough)\n",
    "\n",
    "You can also refer to the [PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some hyperparameters. They are maybe not the best, you need to tune, especially the learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE = 1e-3\n",
    "NUM_EPOCHS = 20\n",
    "BATCH_SIZE = 32 # Mini-batch size\n",
    "DEVICE = torch.device('cuda')\n",
    "CONFIG = {\"lr\": LEARNING_RATE,\n",
    "          \"num_epochs\": NUM_EPOCHS,\n",
    "          \"batch_size\": BATCH_SIZE}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO**: Define your NN here.\n",
    "\n",
    "Some reference networks:\n",
    "* LeNet\n",
    "* AlexNet\n",
    "* VGG\n",
    "* GoogleLeNet (maybe to large for your PC to train)\n",
    "* ResNet (too large)\n",
    "\n",
    "Ref: [Overview of CNN Development - Zhihu](https://zhuanlan.zhihu.com/p/66215918)\n",
    "\n",
    "Do not directly copy others' code. Check these networks' structure and write code yourself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"\n",
    "    TODO: Implementation of a simple Convolutional neural network.\n",
    "    HINT: You can refer to several famous CNNs, like LeNet5, VGG16, ResNet\n",
    "        Actually you cannot directly copy the code in the demo,\n",
    "        since you need to recalculate the shape of the tensor.\n",
    "    \"\"\"\n",
    "    \"\"\"YOUR CODE HERE\"\"\"\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        pass # your code\n",
    "        \n",
    "    def forward(self, x):\n",
    "        pass # yuor code\n",
    "    \"\"\"END OF YOUR CODE\"\"\"\n",
    "\n",
    "model = net().to(device=DEVICE)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load CIFAR-10 dataset. Downloading will be started automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], # magic numbers\n",
    "                         std=[0.229, 0.224, 0.225]) # 3 channels\n",
    "])\n",
    "\n",
    "train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = F.cross_entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the optimizer. You can **choose other** optimizers as you like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model_eval, loader_eval, criterion_eval):\n",
    "    model_eval.eval()\n",
    "    loss_eval = 0\n",
    "    correct = 0.\n",
    "    pbar = tqdm(total = len(loader_eval), desc='Evaluation', ncols=100)\n",
    "    with torch.no_grad():\n",
    "        for data, target in loader_eval:\n",
    "            data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "            output = model(data)\n",
    "            loss_eval += criterion_eval(output, target).item()\n",
    "\n",
    "            pred = output.argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "            pbar.update(1)\n",
    "    pbar.close()\n",
    "\n",
    "    loss_eval = loss_eval / loader_eval.dataset.__len__()\n",
    "    accuracy = correct / loader_eval.dataset.__len__()\n",
    "    response = {'loss': loss_eval, 'acc': accuracy}\n",
    "    return response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main training function defined below. Be careful of [overfitting](https://en.wikipedia.org/wiki/Overfitting)!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_acc = np.zeros(NUM_EPOCHS)\n",
    "eval_acc = np.zeros(NUM_EPOCHS)\n",
    "train_loss = np.zeros(NUM_EPOCHS)\n",
    "eval_loss = np.zeros(NUM_EPOCHS)\n",
    "\n",
    "model.train()\n",
    "for epoch_idx in range(NUM_EPOCHS):\n",
    "    pbar = tqdm(total = len(train_dataloader), desc='Train - Epoch {}'.format(epoch_idx), ncols=100)\n",
    "    for batch_idx, (data, target) in enumerate(train_dataloader):\n",
    "        data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        pbar.update(1)\n",
    "    pbar.close()\n",
    "\n",
    "    train_resp = evaluate(model, train_dataloader, criterion)\n",
    "    eval_resp = evaluate(model, test_dataloader, criterion)\n",
    "\n",
    "    print ('-*-*-*-*-*- Epoch {} -*-*-*-*-*-'.format(epoch_idx))\n",
    "    print ('Train Loss: {:.6f}\\t'.format(train_resp['loss']))\n",
    "    print ('Train Acc: {:.6f}\\t'.format(train_resp['acc']))\n",
    "    print ('Eval Loss: {:.6f}\\t'.format(eval_resp['loss']))\n",
    "    print ('Eval Acc: {:.6f}\\t'.format(eval_resp['acc']))\n",
    "    print ('\\n')\n",
    "\n",
    "    train_acc[epoch_idx] = train_resp['acc']\n",
    "    eval_acc[epoch_idx] = eval_resp['acc']\n",
    "    train_loss[epoch_idx] = train_resp['loss']\n",
    "    eval_loss[epoch_idx] = eval_resp['loss']\n",
    "\n",
    "    # save model and training data\n",
    "    torch.save(model, 'simple-cnn.pth'.format(CONFIG[\"network\"]))\n",
    "    np.savez('simple-cnn',config=CONFIG,train_acc=train_acc, eval_acc=eval_acc,train_loss=train_loss, eval_loss=eval_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visualize the training process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "train_acc = np.zeros(NUM_EPOCHS)\n",
    "eval_acc = np.zeros(NUM_EPOCHS)\n",
    "train_loss = np.zeros(NUM_EPOCHS)\n",
    "eval_loss = np.zeros(NUM_EPOCHS)\n",
    "\n",
    "def plot_acc(train_acc,eval_acc):\n",
    "    plt.plot(train_acc,label=\"Train\")\n",
    "    plt.plot(eval_acc,label=\"Test\")\n",
    "    plt.xlabel(\"Epoch\")\n",
    "    plt.ylabel(\"Accuracy\")\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "def plot_loss(train_loss,eval_loss):\n",
    "    plt.plot(train_loss,label=\"Train\")\n",
    "    plt.plot(eval_loss,label=\"Test\")\n",
    "    plt.xlabel(\"Epoch\")\n",
    "    plt.ylabel(\"Loss\")\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "    \n",
    "plot_acc(train_acc,eval_acc)\n",
    "plot_loss(train_loss,eval_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
